package com.uxsino.commons.utils;


import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.stream.Stream;

import javax.crypto.Cipher;

/**
 * RSA 加密/解密
 * @author RanXC
 */
public class RSA {
    public static final String CHARSET =  "UTF-8";

    /**
     * Bytes of every RSA block consumed by PKCS#1 v1.5 encryption padding (RFC 8017 §7.2).
     * Assumes the cipher uses the provider's default padding for "RSA" (PKCS1Padding).
     */
    private static final int PKCS1_PADDING_OVERHEAD = 11;

    private String charset =  CHARSET;

    private String algorithm = "RSA";

    // NOTE(review): MD5withRSA is kept as the default so existing signatures still verify;
    // MD5 is cryptographically broken, so new callers should configure e.g. "SHA256withRSA".
    private String signAlgorithm = "MD5withRSA";

    private String decodeprovider = "SunJCE";

    private String keyProvider = "SunRsaSign";

    private String encodeProvider = "SunRsaSign";

    /** Key size in bits used by {@link #generateKey()}; default 1024. */
    private int keySize = 1024;

    private RSA(){
        // Building a BouncyCastleProvider registers hundreds of algorithms; only do it once per JVM.
        if (Security.getProvider(org.bouncycastle.jce.provider.BouncyCastleProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider());
        }
    }

    /**
     * Creates a new instance with default settings.
     *
     * @return a fresh, configurable {@code RSA} helper
     */
    public static RSA of(){
        return new RSA();
    }

    /**
     * Sets the key size (in bits) used by {@link #generateKey()}.
     *
     * @param keySize 512-1023 in multiples of 64, or 1024/2048/3072/4096
     * @return this instance
     * @throws InvalidParameterException if the size is outside the accepted values
     */
    public RSA keySize(int keySize){
        if (keySize < 512 || keySize > 4096 || ((keySize < 1024) && keySize % 64 != 0) || (keySize >= 1024 && keySize % 1024 != 0))
        {
            throw new InvalidParameterException("strength must be from 512 - 4096 and a multiple of 1024 above 1024");
        }
        this.keySize = keySize;
        return this;
    }

    /**
     * Sets the key/cipher algorithm name (used for key generation, key decoding and the cipher).
     *
     * @param algorithm algorithm name, e.g. "RSA"
     * @return this instance
     * @throws NoSuchAlgorithmException declared for compatibility; never thrown here
     */
    public RSA algorithm(String algorithm) throws NoSuchAlgorithmException {
        this.algorithm = algorithm;
        return this;
    }

    /**
     * Sets the signature algorithm used by {@code sign}/{@code validate}.
     * (Name kept for compatibility; see {@link #signAlgorithm(String)}.)
     *
     * @param algorithm signature algorithm name, e.g. "SHA256withRSA"
     * @return this instance
     */
    public RSA singAlgorithm(String algorithm){
        return signAlgorithm(algorithm);
    }

    /**
     * Sets the signature algorithm used by {@code sign}/{@code validate}.
     *
     * @param algorithm signature algorithm name, e.g. "SHA256withRSA"
     * @return this instance
     */
    public RSA signAlgorithm(String algorithm){
        this.signAlgorithm = algorithm;
        return this;
    }

    /**
     * Generates a new key pair with the configured algorithm and key size.
     *
     * @return {@code [publicKey, privateKey]}, Base64 encoded (X.509 and PKCS#8 respectively)
     * @throws Exception if the algorithm or provider is unavailable
     */
    public String[] generateKey() throws Exception {
        KeyPairGenerator generator = KeyPairGenerator.getInstance(this.algorithm, this.keyProvider);
        generator.initialize(this.keySize);
        KeyPair pair = generator.generateKeyPair();
        byte[] pubKey = pair.getPublic().getEncoded();
        byte[] prvKey = pair.getPrivate().getEncoded();
        return new String[] { Base64.getEncoder().encodeToString(pubKey),
                Base64.getEncoder().encodeToString(prvKey) };
    }

    /**
     * Decrypts data that was produced by {@link #encode(byte[], Key)} with the matching key.
     *
     * <p>The input is split into modulus-sized blocks, taken from the key itself, and each block is
     * decrypted independently.
     *
     * @param data ciphertext; its length must be a multiple of the key's modulus length
     * @param key  public or private key matching the key used for encryption
     * @return the concatenated plaintext
     * @throws Exception if the cipher is unavailable or any block fails to decrypt
     */
    public byte[] decode(byte[] data, Key key) throws Exception {
        Cipher cipher = Cipher.getInstance(key.getAlgorithm()==null?this.algorithm:key.getAlgorithm(), this.decodeprovider);
        cipher.init(Cipher.DECRYPT_MODE, key);
        return doFinalInBlocks(cipher, data, modulusBytes(key));
    }

    /**
     * Decrypts data with a Base64-encoded (X.509) public key.
     *
     * @param data ciphertext
     * @param key  Base64 public key
     * @return the plaintext
     * @throws Exception on key parsing or decryption failure
     */
    public byte[] decode(byte[] data, String key)throws Exception{
        return decode(data, decodePublicKey(key));
    }

    /**
     * Decrypts data with a Base64-encoded (PKCS#8) private key.
     *
     * @param data ciphertext
     * @param key  Base64 private key
     * @return the plaintext
     * @throws Exception on key parsing or decryption failure
     */
    public byte[] decodeByPrivateKey(byte[] data, String key)throws Exception{
        return decode(data, decodePrivateKey(key));
    }

    /**
     * Encrypts data with the given key.
     *
     * <p>The input is split into chunks of {@code modulusBytes - 11} bytes (the PKCS#1 v1.5 limit;
     * 117 for a 1024-bit key). Each chunk becomes one modulus-sized ciphertext block.
     *
     * @param data plaintext
     * @param key  public or private key
     * @return the concatenated ciphertext blocks
     * @throws Exception if the cipher is unavailable or any block fails to encrypt
     */
    public byte[] encode(byte[] data, Key key) throws Exception{
        Cipher cipher = Cipher.getInstance(this.algorithm);
        cipher.init(Cipher.ENCRYPT_MODE, key);
        return doFinalInBlocks(cipher, data, modulusBytes(key) - PKCS1_PADDING_OVERHEAD);
    }

    /**
     * Returns the RSA modulus length in bytes, i.e. the size of each ciphertext block.
     * Falls back to the configured {@link #keySize} when the key does not expose its modulus.
     */
    private int modulusBytes(Key key) {
        if (key instanceof RSAKey) {
            return (((RSAKey) key).getModulus().bitLength() + 7) / 8;
        }
        return this.keySize / 8;
    }

    /**
     * Runs {@code cipher.doFinal} on consecutive slices of at most {@code blockSize} bytes and
     * concatenates the results. Any cipher failure propagates instead of yielding partial output.
     */
    private static byte[] doFinalInBlocks(Cipher cipher, byte[] data, int blockSize) throws GeneralSecurityException {
        if (blockSize <= 0) {
            throw new IllegalArgumentException("RSA block size must be positive, got " + blockSize);
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int offset = 0; offset < data.length; offset += blockSize) {
            int length = Math.min(blockSize, data.length - offset);
            byte[] block = cipher.doFinal(data, offset, length);
            out.write(block, 0, block.length);
        }
        return out.toByteArray();
    }

    /**
     * Calls {@code consumer} with inclusive {@code [from, to]} index ranges covering {@code data} in
     * steps of {@code segSize}, starting at {@code seed}. The number of ranges is
     * {@code ceil(data.length / segSize)}. Exceptions thrown by the consumer are printed and
     * otherwise ignored, so processing continues with the next range.
     *
     * @param data    array to split
     * @param seed    first start index
     * @param segSize segment length; must be positive
     * @param consumer callback receiving each inclusive range
     */
    public static void  segment(byte[] data, int seed, int segSize, SegSeperator consumer){
        long size = data.length / segSize +((data.length%segSize==0?0:1));
        Stream.iterate(seed, k->k+segSize).limit(size).forEach(i->{
            int end = i + segSize -1;
            if(end >= data.length){
                end = data.length - 1;
            }
            try {
                consumer.deal(i, end);
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
    }

    /** Callback for {@link #segment}: receives an inclusive index range. */
    @FunctionalInterface
    public interface SegSeperator{
        public void deal(int from, int to);
    }

    /**
     * Encrypts data with a Base64-encoded (PKCS#8) private key.
     *
     * @param data plaintext
     * @param key  Base64 private key
     * @return the ciphertext
     * @throws Exception on key parsing or encryption failure
     */
    public byte[] encode(byte[] data, String key)throws Exception{
        return encode(data, decodePrivateKey(key));
    }

    /**
     * Encrypts data with a Base64-encoded (X.509) public key.
     *
     * @param data plaintext
     * @param key  Base64 public key
     * @return the ciphertext
     * @throws Exception on key parsing or encryption failure
     */
    public byte[] encodeByPublicKey(byte[] data, String key)throws Exception{
        return encode(data, decodePublicKey(key));
    }

    /**
     * Parses a Base64-encoded X.509 public key.
     *
     * @param publicKey Base64 text
     * @return the public key
     * @throws Exception if the text or key encoding is invalid
     */
    public RSAPublicKey decodePublicKey(String publicKey) throws Exception{
        byte[] keyBytes = Base64.getDecoder().decode(publicKey);
        X509EncodedKeySpec spec = new X509EncodedKeySpec(keyBytes);
        KeyFactory keyFactory = KeyFactory.getInstance(this.algorithm);
        return (RSAPublicKey) keyFactory.generatePublic(spec);
    }

    /**
     * Parses a Base64-encoded PKCS#8 private key.
     *
     * @param privateKey Base64 text
     * @return the private key
     * @throws Exception if the text or key encoding is invalid
     */
    public RSAPrivateKey decodePrivateKey(String privateKey) throws Exception{
        byte[] keyBytes = Base64.getDecoder().decode(privateKey);
        PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
        KeyFactory keyFactory = KeyFactory.getInstance(this.algorithm);
        return (RSAPrivateKey) keyFactory.generatePrivate(spec);
    }

    /**
     * Signs data with the configured signature algorithm.
     *
     * @param data data to sign
     * @param key  private key
     * @return the raw signature bytes
     * @throws Exception if the algorithm is unavailable or signing fails
     */
    public byte[] sign(byte[] data, PrivateKey key) throws Exception {
        Signature sign = Signature.getInstance(this.signAlgorithm);
        sign.initSign(key);
        sign.update(data);
        return sign.sign();
    }

    /**
     * Signs data with a Base64-encoded (PKCS#8) private key.
     *
     * @param data       data to sign
     * @param privateKey Base64 private key
     * @return the raw signature bytes
     * @throws Exception on key parsing or signing failure
     */
    public byte[] sign(byte[] data, String privateKey) throws Exception {
        return sign(data, decodePrivateKey(privateKey));
    }

    /**
     * Verifies a signature with the configured signature algorithm.
     *
     * @param data signed data
     * @param key  public key
     * @param sign raw signature bytes
     * @return {@code true} if the signature is valid
     * @throws Exception if the algorithm is unavailable or the signature is malformed
     */
    public boolean validate(byte[] data, PublicKey key, byte[] sign) throws Exception {
        Signature signature = Signature.getInstance(this.signAlgorithm);
        signature.initVerify(key);
        signature.update(data);
        return signature.verify(sign);
    }

    /**
     * Verifies a Base64-encoded signature with a Base64-encoded (X.509) public key.
     *
     * @param data       signed data
     * @param publicKey  Base64 public key
     * @param base64sign Base64 signature
     * @return {@code true} if the signature is valid
     * @throws Exception on decoding or verification failure
     */
    public boolean validate(byte[] data, String publicKey, String base64sign) throws Exception {
        return validate(data, decodePublicKey(publicKey), Base64.getDecoder().decode(base64sign));
    }

    /**
     * Byte used to pad output.
     */
    protected static final byte PAD_DEFAULT = '='; // Allow static access to default

    /**
     * This array is a lookup table that translates Unicode characters drawn from the "Base64 Alphabet" (as specified
     * in Table 1 of RFC 2045) into their 6-bit positive integer equivalents. Characters that are not in the Base64
     * alphabet but fall within the bounds of the array are translated to -1.
     *
     * Note: '+' and '-' both decode to 62. '/' and '_' both decode to 63. This means decoder seamlessly handles both
     * URL_SAFE and STANDARD base64. (The encoder, on the other hand, needs to know ahead of time what to emit).
     *
     * Thanks to "commons" project in ws.apache.org for this code.
     * http://svn.apache.org/repos/asf/webservices/commons/trunk/modules/util/
     */
    private static final byte[] DECODE_TABLE = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1,
            62, -1, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
            9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, 63, -1, 26, 27, 28, 29,
            30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 };

    /**
     * Returns whether or not the <code>octet</code> is in the base 64 alphabet (including the pad character).
     *
     * @param octet
     *            The value to test
     * @return <code>true</code> if the value is defined in the base 64 alphabet, <code>false</code> otherwise.
     * @since 1.4
     */
    public static boolean isBase64(final byte octet) {
        return octet == PAD_DEFAULT || (octet >= 0 && octet < DECODE_TABLE.length && DECODE_TABLE[octet] != -1);
    }

    /**
     * Tests a given String to see if it contains only valid characters within the Base64 alphabet. Currently the
     * method treats whitespace as valid.
     *
     * @param base64
     *            String to test
     * @return <code>true</code> if all characters in the String are valid characters in the Base64 alphabet or if
     *         the String is empty; <code>false</code>, otherwise
     * @throws Exception declared for compatibility; UTF-8 is always available
     * @since 1.5
     */
    public static boolean isBase64(final String base64) throws Exception {
        return isBase64(base64.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Tests a given byte array to see if it contains only valid characters within the Base64 alphabet. Currently the
     * method treats whitespace as valid.
     *
     * @param arrayOctet
     *            byte array to test
     * @return <code>true</code> if all bytes are valid characters in the Base64 alphabet or if the byte array is empty;
     *         <code>false</code>, otherwise
     * @since 1.5
     */
    public static boolean isBase64(final byte[] arrayOctet) {
        for (byte octet : arrayOctet) {
            if (!isBase64(octet) && !isWhiteSpace(octet)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks if a byte value is whitespace or not.
     * Whitespace is taken to mean: space, tab, CR, LF
     * @param byteToCheck
     *            the byte to check
     * @return true if byte is whitespace, false otherwise
     */
    protected static boolean isWhiteSpace(final byte byteToCheck) {
        switch (byteToCheck) {
            case ' ' :
            case '\n' :
            case '\r' :
            case '\t' :
                return true;
            default :
                return false;
        }
    }
}